require "stringio"
require_relative "polyfill/unpack1"

module Prism
  # A module responsible for deserializing parse results.
  module Serialize
    # The major version of prism that we are expecting to find in the serialized
    # strings. Loader#load_header compares the three version bytes of the payload
    # against MAJOR_VERSION, MINOR_VERSION, and PATCH_VERSION and raises when any
    # of them differ.
    MAJOR_VERSION = 1

    # The minor version of prism that we are expecting to find in the serialized
    # strings.
    MINOR_VERSION = 6

    # The patch version of prism that we are expecting to find in the serialized
    # strings.
    PATCH_VERSION = 0

    # Deserialize the dumped output from a request to parse or parse_file.
    #
    # +input+ is the source string that was parsed, +serialized+ is the dumped
    # payload (Loader raises unless it is binary-encoded), and +freeze+ controls
    # whether the result and the objects it references are frozen before being
    # returned. Returns a ParseResult. Raises if the header is invalid or if the
    # loader does not report the end of the payload once everything was read.
    #
    # The formatting of the source of this method is purposeful to illustrate
    # the structure of the serialized data.
    def self.load_parse(input, serialized, freeze)
      # Work on a copy so that the force_encoding calls below do not touch the
      # string object that the caller passed in.
      input = input.dup
      source = Source.for(input)
      loader = Loader.new(source, serialized)

      # The reads below must happen in exactly this order because each one
      # consumes the next section of the serialized stream.
                       loader.load_header
      encoding =       loader.load_encoding
      start_line =     loader.load_varsint
      offsets =        loader.load_line_offsets(freeze)

      # The locations created below reference the source, so it is updated with
      # the start line and line offsets before they are loaded.
      source.replace_start_line(start_line)
      source.replace_offsets(offsets)

      comments =       loader.load_comments(freeze)
      magic_comments = loader.load_magic_comments(freeze)
      data_loc =       loader.load_optional_location_object(freeze)
      errors =         loader.load_errors(encoding, freeze)
      warnings =       loader.load_warnings(encoding, freeze)
      cpool_base =     loader.load_uint32
      cpool_size =     loader.load_varuint

      # The constant pool resolves its entries lazily from the input and the
      # serialized string, using the base offset and size read above.
      constant_pool = ConstantPool.new(input, serialized, cpool_base, cpool_size)

      node =           loader.load_node(constant_pool, encoding, freeze)
                       loader.load_constant_pool(constant_pool)
      # A bare raise here results in a RuntimeError without a message.
      raise unless     loader.eof?

      result = ParseResult.new(node, comments, magic_comments, data_loc, errors, warnings, source)
      result.freeze if freeze

      # Tag the copied input with the encoding that the payload reported.
      input.force_encoding(encoding)

      # This is an extremely niche use-case where the file was marked as binary
      # but it contained UTF-8-encoded characters. In that case we will actually
      # put it back to UTF-8 to give the location APIs the best chance of being
      # correct.
      if !input.ascii_only? && input.encoding == Encoding::BINARY
        input.force_encoding(Encoding::UTF_8)
        # Fall back to binary when the bytes are not valid UTF-8 after all.
        input.force_encoding(Encoding::BINARY) unless input.valid_encoding?
      end

      if freeze
        input.freeze
        source.deep_freeze
      end

      result
    end

    # Deserialize the dumped output from a request to lex or lex_file.
    #
    # +input+ is the source string that was lexed, +serialized+ is the dumped
    # payload (Loader raises unless it is binary-encoded), and +freeze+ controls
    # whether the tokens, the source, and the result are frozen before being
    # returned. Returns a LexResult whose value is the list of
    # [token, lex_state] pairs. Unlike load_parse, no header is read here.
    #
    # The formatting of the source of this method is purposeful to illustrate
    # the structure of the serialized data.
    def self.load_lex(input, serialized, freeze)
      source = Source.for(input)
      loader = Loader.new(source, serialized)

      # The tokens come first in the stream, before the encoding is known.
      tokens =         loader.load_tokens
      encoding =       loader.load_encoding
      start_line =     loader.load_varsint
      offsets =        loader.load_line_offsets(freeze)

      source.replace_start_line(start_line)
      source.replace_offsets(offsets)

      comments =       loader.load_comments(freeze)
      magic_comments = loader.load_magic_comments(freeze)
      data_loc =       loader.load_optional_location_object(freeze)
      errors =         loader.load_errors(encoding, freeze)
      warnings =       loader.load_warnings(encoding, freeze)
      # A bare raise here results in a RuntimeError without a message.
      raise unless     loader.eof?

      result = LexResult.new(tokens, comments, magic_comments, data_loc, errors, warnings, source)

      # Each entry is a [token, lex_state] pair. The token values were sliced
      # before the encoding was read, so they are tagged with it here.
      tokens.each do |token|
        token[0].value.force_encoding(encoding)

        if freeze
          token[0].deep_freeze
          token.freeze
        end
      end

      if freeze
        source.deep_freeze
        tokens.freeze
        result.freeze
      end

      result
    end

    # Deserialize the dumped output from a request to parse_comments or
    # parse_file_comments.
    #
    # +input+ is the source string that was parsed, +serialized+ is the dumped
    # payload (Loader raises unless it is binary-encoded), and +freeze+ controls
    # whether the comments and the source are frozen. Returns the array of
    # comment objects produced by Loader#load_comments.
    #
    # The formatting of the source of this method is purposeful to illustrate
    # the structure of the serialized data.
    def self.load_parse_comments(input, serialized, freeze)
      source = Source.for(input)
      loader = Loader.new(source, serialized)

      # The encoding's return value is not needed here, but it still has to be
      # read to advance the stream (load_encoding also re-tags the loader's
      # copy of the input).
                   loader.load_header
                   loader.load_encoding
      start_line = loader.load_varsint

      source.replace_start_line(start_line)

      result =     loader.load_comments(freeze)
      # A bare raise here results in a RuntimeError without a message.
      raise unless loader.eof?

      source.deep_freeze if freeze
      result
    end

    # Deserialize the dumped output from a request to parse_lex or
    # parse_lex_file.
    #
    # +input+ is the source string that was parsed, +serialized+ is the dumped
    # payload (Loader raises unless it is binary-encoded), and +freeze+ controls
    # whether the tokens, the source, and the result are frozen before being
    # returned. Returns a ParseLexResult whose value is the two-element array
    # [node, tokens], where tokens is a list of [token, lex_state] pairs.
    #
    # The formatting of the source of this method is purposeful to illustrate
    # the structure of the serialized data.
    def self.load_parse_lex(input, serialized, freeze)
      source = Source.for(input)
      loader = Loader.new(source, serialized)

      # The tokens precede the header in this payload; everything after them is
      # read in the same order as in load_parse.
      tokens =         loader.load_tokens
                       loader.load_header
      encoding =       loader.load_encoding
      start_line =     loader.load_varsint
      offsets =        loader.load_line_offsets(freeze)

      source.replace_start_line(start_line)
      source.replace_offsets(offsets)

      comments =       loader.load_comments(freeze)
      magic_comments = loader.load_magic_comments(freeze)
      data_loc =       loader.load_optional_location_object(freeze)
      errors =         loader.load_errors(encoding, freeze)
      warnings =       loader.load_warnings(encoding, freeze)
      cpool_base =     loader.load_uint32
      cpool_size =     loader.load_varuint

      # The constant pool resolves its entries lazily from the input and the
      # serialized string, using the base offset and size read above.
      constant_pool = ConstantPool.new(input, serialized, cpool_base, cpool_size)

      node =           loader.load_node(constant_pool, encoding, freeze)
                       loader.load_constant_pool(constant_pool)
      # A bare raise here results in a RuntimeError without a message.
      raise unless     loader.eof?

      value = [node, tokens]
      result = ParseLexResult.new(value, comments, magic_comments, data_loc, errors, warnings, source)

      # Each entry is a [token, lex_state] pair. The token values were sliced
      # before the encoding was read, so they are tagged with it here.
      tokens.each do |token|
        token[0].value.force_encoding(encoding)

        if freeze
          token[0].deep_freeze
          token.freeze
        end
      end

      if freeze
        source.deep_freeze
        tokens.freeze
        value.freeze
        result.freeze
      end

      result
    end

    # Lazily resolves constants out of a serialized payload. Every entry is a
    # pair of native-endian 32-bit integers (start, length) stored at
    # base + index * 8 in the serialized string. Resolved symbols are cached so
    # that each entry is decoded at most once.
    class ConstantPool # :nodoc:
      attr_reader :size

      def initialize(input, serialized, base, size)
        @input = input
        @serialized = serialized
        @base = base
        @size = size
        @pool = Array.new(size)
      end

      # Return the symbol stored at the given zero-based index, tagging its
      # bytes with the given encoding first.
      def get(index, encoding)
        cached = @pool[index]
        return cached if cached

        entry = @base + index * 8
        start = @serialized.unpack1("L", offset: entry)
        length = @serialized.unpack1("L", offset: entry + 4)

        # When the top bit of the start offset is set, the bytes live in the
        # serialized string itself (at the offset held in the remaining 31
        # bits). Otherwise they are a slice of the original input.
        embedded_flag = 1 << 31
        bytes =
          if start.anybits?(embedded_flag)
            @serialized.byteslice(start & (embedded_flag - 1), length)
          else
            @input.byteslice(start, length)
          end

        @pool[index] = bytes.force_encoding(encoding).to_sym
      end
    end

    if RUBY_ENGINE != "truffleruby"
      FastStringIO = ::StringIO # :nodoc:
    else
      # StringIO is synchronized and that adds a high overhead on TruffleRuby.
      # This class implements only the subset of the StringIO interface that
      # the loader uses: pos, pos=, getbyte, read, and eof?.
      class FastStringIO # :nodoc:
        attr_accessor :pos

        def initialize(string)
          @pos = 0
          @string = string
        end

        # Return the byte at the current position and advance by one.
        def getbyte
          current = @pos
          @pos = current + 1
          @string.getbyte(current)
        end

        # Return the next n bytes and advance the position by n.
        def read(n)
          from = @pos
          @pos = from + n
          @string.byteslice(from, n)
        end

        # True once the position has reached or passed the end of the string.
        def eof?
          @string.bytesize <= @pos
        end
      end
    end

    class Loader # :nodoc:
      attr_reader :input, :io, :source

      # Build a loader that reads from the given serialized string, which must
      # be binary-encoded. The loader keeps its own copy of the source's string
      # as +input+ and the source itself as +source+.
      def initialize(source, serialized)
        @input = source.source.dup
        raise if serialized.encoding != Encoding::BINARY

        @source = source
        @io = FastStringIO.new(serialized)

        # Only engines other than CRuby dispatch node loading through lambdas.
        define_load_node_lambdas unless RUBY_ENGINE == "ruby"
      end

      # Consume one more byte from the stream and then report whether the stream
      # is exhausted. Note that every call advances the stream by one byte, so
      # this is not a side-effect-free predicate.
      # NOTE(review): presumably the discarded byte is a terminator written at
      # the end of the serialized payload — confirm against the serializer.
      def eof?
        io.getbyte
        io.eof?
      end

      # Skip over the constant pool section of the stream: one 8-byte
      # (start, length) entry per constant, followed by the bytes of every
      # constant whose start offset has its top bit set. Returns whatever the
      # final read of those trailing bytes returns.
      def load_constant_pool(constant_pool)
        embedded_bytes =
          constant_pool.size.times.sum do
            offset, bytesize = io.read(8).unpack("L2")
            offset.anybits?(1 << 31) ? bytesize : 0
          end

        io.read(embedded_bytes)
      end

      # Validate the 9-byte header: the "PRISM" magic string, three version
      # bytes that must equal MAJOR_VERSION, MINOR_VERSION, and PATCH_VERSION,
      # and one byte that must be zero (otherwise location fields were omitted
      # from the payload). Raises a RuntimeError on any mismatch.
      def load_header
        magic = io.read(5)
        raise "Invalid serialization" unless magic == "PRISM"

        version = io.read(3).unpack("C3")
        raise "Invalid serialization" unless version == [MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION]

        locations_omitted = io.getbyte
        raise "Invalid serialization (location fields must be included but are not)" unless locations_omitted == 0
      end

      # Read a length-prefixed encoding name from the stream and look it up.
      # As a side effect the loader's copy of the input is tagged with that
      # encoding and frozen. Returns the Encoding.
      def load_encoding
        name = io.read(load_varuint)

        Encoding.find(name).tap do |encoding|
          @input = input.force_encoding(encoding).freeze
        end
      end

      # Read a count followed by that many varuints and return them as an
      # array, frozen when +freeze+ is truthy.
      def load_line_offsets(freeze)
        Array.new(load_varuint) { load_varuint }.tap do |offsets|
          offsets.freeze if freeze
        end
      end

      # Read a count followed by that many comments. Each comment is a type
      # (0 for InlineComment, 1 for EmbDocComment) followed by a location. Any
      # other type yields nil without reading a location. The comments and the
      # array are frozen when +freeze+ is truthy.
      def load_comments(freeze)
        comments =
          Array.new(load_varuint) do
            comment_class =
              case load_varuint
              when 0 then InlineComment
              when 1 then EmbDocComment
              end

            # Safe navigation skips evaluating the argument when the type was
            # not recognized, so no location is consumed in that case.
            comment = comment_class&.new(load_location_object(freeze))
            comment.freeze if freeze
            comment
          end

        freeze ? comments.freeze : comments
      end

      # Read a count followed by that many magic comments, each of which is
      # serialized as two consecutive locations. The magic comments and the
      # array are frozen when +freeze+ is truthy.
      def load_magic_comments(freeze)
        magic_comments =
          Array.new(load_varuint) do
            first_location = load_location_object(freeze)
            second_location = load_location_object(freeze)

            MagicComment.new(first_location, second_location).tap do |magic_comment|
              magic_comment.freeze if freeze
            end
          end

        freeze ? magic_comments.freeze : magic_comments
      end

      # The symbol for every diagnostic, generated from the template's error
      # list followed by its warning list. load_errors and load_warnings both
      # index into this single table with the varuint read from the stream, so
      # the order here must match the order used by the serializer.
      DIAGNOSTIC_TYPES = [
        <%- errors.each do |error| -%>
        <%= error.name.downcase.to_sym.inspect %>,
        <%- end -%>
        <%- warnings.each do |warning| -%>
        <%= warning.name.downcase.to_sym.inspect %>,
        <%- end -%>
      ].freeze

      private_constant :DIAGNOSTIC_TYPES

      # Read one byte and translate it into an error level symbol: 0 is
      # :syntax, 1 is :argument, and 2 is :load. Raises a RuntimeError for any
      # other value.
      def load_error_level
        level = io.getbyte

        { 0 => :syntax, 1 => :argument, 2 => :load }.fetch(level) do
          raise "Unknown level: #{level}"
        end
      end

      # Read a count followed by that many errors. Each error is serialized as
      # an index into DIAGNOSTIC_TYPES, an embedded message string, a location,
      # and a level byte, in that order. The errors and the array are frozen
      # when +freeze+ is truthy.
      def load_errors(encoding, freeze)
        errors =
          Array.new(load_varuint) do
            type = DIAGNOSTIC_TYPES.fetch(load_varuint)
            message = load_embedded_string(encoding)
            location = load_location_object(freeze)
            level = load_error_level

            ParseError.new(type, message, location, level).tap do |error|
              error.freeze if freeze
            end
          end

        freeze ? errors.freeze : errors
      end

      # Read one byte and translate it into a warning level symbol: 0 is
      # :default and 1 is :verbose. Raises a RuntimeError for any other value.
      def load_warning_level
        level = io.getbyte

        { 0 => :default, 1 => :verbose }.fetch(level) do
          raise "Unknown level: #{level}"
        end
      end

      # Read a count followed by that many warnings. Each warning is serialized
      # as an index into DIAGNOSTIC_TYPES, an embedded message string, a
      # location, and a level byte, in that order. The warnings and the array
      # are frozen when +freeze+ is truthy.
      def load_warnings(encoding, freeze)
        warnings =
          Array.new(load_varuint) do
            type = DIAGNOSTIC_TYPES.fetch(load_varuint)
            message = load_embedded_string(encoding)
            location = load_location_object(freeze)
            level = load_warning_level

            ParseWarning.new(type, message, location, level).tap do |warning|
              warning.freeze if freeze
            end
          end

        freeze ? warnings.freeze : warnings
      end

      # Read tokens until a type that maps to a falsy entry of TOKEN_TYPES
      # (index 0 is nil) is encountered. Each token is serialized as its type,
      # start offset, length, and lex state. Returns an array of
      # [token, lex_state] pairs.
      def load_tokens
        tokens = []

        while (type = TOKEN_TYPES.fetch(load_varuint))
          start, length, lex_state = load_varuint, load_varuint, load_varuint
          location = Location.new(@source, start, length)

          tokens << [Token.new(@source, type, location.slice, location), lex_state]
        end

        tokens
      end

      # Read an unsigned variable-length integer encoded with LEB128
      # (https://en.wikipedia.org/wiki/LEB128), which is also what protobuf
      # uses: https://protobuf.dev/programming-guides/encoding/#varints
      #
      # Every byte contributes its low 7 bits, least-significant group first,
      # and a set high bit means that another byte follows.
      def load_varuint
        byte = io.getbyte
        return byte if byte < 128

        result = byte & 0x7f
        shift = 7

        while (byte = io.getbyte) >= 128
          result |= (byte & 0x7f) << shift
          shift += 7
        end

        result | (byte << shift)
      end

      # Read a signed variable-length integer. The value is stored as a varuint
      # whose lowest bit carries the sign: even values decode to n / 2 and odd
      # values decode to -(n / 2) - 1.
      def load_varsint
        encoded = load_varuint
        half = encoded >> 1

        encoded.odd? ? ~half : half
      end

      # Read an arbitrary-precision integer: a sign byte (non-zero means
      # negative), a word count, and then that many varuints which are combined
      # as 32-bit words, least-significant word first.
      def load_integer
        negative = io.getbyte != 0

        magnitude =
          (0...load_varuint).reduce(0) do |accumulated, word|
            accumulated | (load_varuint << (word * 32))
          end

        negative ? -magnitude : magnitude
      end

      # Read 8 bytes and decode them as a native-format double.
      def load_double
        bytes = io.read(8)
        bytes.unpack("D").first
      end

      # Read 4 bytes and decode them as a native-endian unsigned 32-bit integer.
      def load_uint32
        bytes = io.read(4)
        bytes.unpack("L").first
      end

      # Load a node that may be absent. A zero byte means that there is no node
      # and nil is returned. Any other byte is the first byte of the node
      # itself, so the stream is rewound by one before delegating to load_node.
      def load_optional_node(constant_pool, encoding, freeze)
        return if io.getbyte == 0

        io.pos -= 1
        load_node(constant_pool, encoding, freeze)
      end

      # Read a length-prefixed string whose bytes are stored directly in the
      # stream. The result is tagged with the given encoding and frozen.
      def load_embedded_string(encoding)
        bytesize = load_varuint
        string = io.read(bytesize)
        string.force_encoding(encoding).freeze
      end

      # Read a string field. The leading byte selects the representation: 1 is
      # a (start, length) slice of the input and 2 is a string embedded in the
      # stream. The result is tagged with the given encoding and frozen. Raises
      # a RuntimeError for any other representation.
      def load_string(encoding)
        type = io.getbyte

        if type == 1
          start = load_varuint
          length = load_varuint
          input.byteslice(start, length).force_encoding(encoding).freeze
        elsif type == 2
          load_embedded_string(encoding)
        else
          raise "Unknown serialized string type: #{type}"
        end
      end

      # Read a start offset and a length and wrap them in a Location that
      # points at the loader's source. The location is frozen when +freeze+ is
      # truthy.
      def load_location_object(freeze)
        start = load_varuint
        length = load_varuint

        Location.new(source, start, length).tap do |location|
          location.freeze if freeze
        end
      end

      # Read a location for a node field. When +freeze+ is truthy a frozen
      # Location object is built eagerly. Otherwise the start offset and length
      # are packed into a single integer (start in the upper bits, length in
      # the lower 32 bits).
      def load_location(freeze)
        if freeze
          load_location_object(freeze)
        else
          start = load_varuint
          (start << 32) | load_varuint
        end
      end

      # Read a presence byte and, unless it is zero, the location that follows
      # it (see load_location). Returns nil when the location is absent.
      def load_optional_location(freeze)
        return if io.getbyte == 0

        load_location(freeze)
      end

      # Read a presence byte and, unless it is zero, the Location object that
      # follows it (see load_location_object). Returns nil when it is absent.
      def load_optional_location_object(freeze)
        return if io.getbyte == 0

        load_location_object(freeze)
      end

      # Read a one-based constant index from the stream and resolve it through
      # the given constant pool, which is zero-based.
      def load_constant(constant_pool, encoding)
        constant_pool.get(load_varuint - 1, encoding)
      end

      # Read a constant index from the stream. Zero means that the constant is
      # absent and nil is returned. Any other value is a one-based index that is
      # resolved through the given constant pool.
      def load_optional_constant(constant_pool, encoding)
        index = load_varuint
        return if index == 0

        constant_pool.get(index - 1, encoding)
      end

      if RUBY_ENGINE == "ruby"
        def load_node(constant_pool, encoding, freeze)
          type = io.getbyte
          node_id = load_varuint
          location = load_location(freeze)
          value = case type
          <%- nodes.each_with_index do |node, index| -%>
          when <%= index + 1 %> then
            <%- if node.needs_serialized_length? -%>
            load_uint32
            <%- end -%>
            <%= node.name %>.new(<%= ["source", "node_id", "location", "load_varuint", *node.fields.map { |field|
              case field
              when Prism::Template::NodeField then "load_node(constant_pool, encoding, freeze)"
              when Prism::Template::OptionalNodeField then "load_optional_node(constant_pool, encoding, freeze)"
              when Prism::Template::StringField then "load_string(encoding)"
              when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node(constant_pool, encoding, freeze) }.tap { |nodes| nodes.freeze if freeze }"
              when Prism::Template::ConstantField then "load_constant(constant_pool, encoding)"
              when Prism::Template::OptionalConstantField then "load_optional_constant(constant_pool, encoding)"
              when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_constant(constant_pool, encoding) }.tap { |constants| constants.freeze if freeze }"
              when Prism::Template::LocationField then "load_location(freeze)"
              when Prism::Template::OptionalLocationField then "load_optional_location(freeze)"
              when Prism::Template::UInt8Field then "io.getbyte"
              when Prism::Template::UInt32Field then "load_varuint"
              when Prism::Template::IntegerField then "load_integer"
              when Prism::Template::DoubleField then "load_double"
              else raise
              end
            }].join(", ") -%>)
            <%- end -%>
          end

          value.freeze if freeze
          value
        end
      else
        def load_node(constant_pool, encoding, freeze)
          @load_node_lambdas[io.getbyte].call(constant_pool, encoding, freeze)
        end

        def define_load_node_lambdas
          @load_node_lambdas = [
            nil,
            <%- nodes.each do |node| -%>
            -> (constant_pool, encoding, freeze) {
              node_id = load_varuint
              location = load_location(freeze)
              <%- if node.needs_serialized_length? -%>
              load_uint32
              <%- end -%>
              value = <%= node.name %>.new(<%= ["source", "node_id", "location", "load_varuint", *node.fields.map { |field|
                case field
                when Prism::Template::NodeField then "load_node(constant_pool, encoding, freeze)"
                when Prism::Template::OptionalNodeField then "load_optional_node(constant_pool, encoding, freeze)"
                when Prism::Template::StringField then "load_string(encoding)"
                when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node(constant_pool, encoding, freeze) }"
                when Prism::Template::ConstantField then "load_constant(constant_pool, encoding)"
                when Prism::Template::OptionalConstantField then "load_optional_constant(constant_pool, encoding)"
                when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_constant(constant_pool, encoding) }"
                when Prism::Template::LocationField then "load_location(freeze)"
                when Prism::Template::OptionalLocationField then "load_optional_location(freeze)"
                when Prism::Template::UInt8Field then "io.getbyte"
                when Prism::Template::UInt32Field then "load_varuint"
                when Prism::Template::IntegerField then "load_integer"
                when Prism::Template::DoubleField then "load_double"
                else raise
                end
              }].join(", ") -%>)
              value.freeze if freeze
              value
            },
            <%- end -%>
          ]
        end
      end
    end

    # The token types that can be indexed by their enum values. Index 0 is nil,
    # which Loader#load_tokens relies on to detect the end of the token list:
    # it keeps reading tokens until the type it fetches from this table is nil.
    TOKEN_TYPES = [
      nil,
      <%- tokens.each do |token| -%>
      <%= token.name.to_sym.inspect %>,
      <%- end -%>
    ].freeze

    # Everything in this module other than the load_* entry points is an
    # implementation detail, so the constants and helper classes are hidden
    # from code outside of it.
    private_constant :MAJOR_VERSION, :MINOR_VERSION, :PATCH_VERSION
    private_constant :ConstantPool, :FastStringIO, :Loader, :TOKEN_TYPES
  end

  # The Serialize module itself is also hidden, so it can only be referenced
  # from code that is lexically inside of the Prism namespace.
  private_constant :Serialize
end
